{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example: Simple Agentic RAG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import wikipedia\n",
    "from markdownify import markdownify\n",
    "from tensorzero import AsyncTensorZeroGateway, ToolCall, ToolResult"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Kill the RAG process after a certain number of inferences to prevent infinite loops.\n",
    "\n",
    "MAX_INFERENCES = 20"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TensorZero Client\n",
    "\n",
    "We initialize an embedded TensorZero client with our configuration file.\n",
    "\n",
    "To keep things minimal in this example, we don't set up observability with ClickHouse.\n",
    "\n",
    "See the [Quick Start](https://www.tensorzero.com/docs/quickstart) for a simple example that includes observability and the UI."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t0 = await AsyncTensorZeroGateway.build_embedded(\n",
    "    config_file=\"config/tensorzero.toml\",\n",
    "    # clickhouse_url=\"...\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tools\n",
    "\n",
    "We define the tools that will be used by the model.\n",
    "\n",
    "Here, we have a tool for searching Wikipedia and a tool for loading a Wikipedia page.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def search_wikipedia(tool_call: ToolCall) -> ToolResult:\n",
    "    \"\"\"\n",
    "    Searches Wikipedia for a given query and returns a list of search results.\n",
    "\n",
    "    Args:\n",
    "        tool_call (ToolCall): A tool call object containing the search query in its arguments.\n",
    "            Expected arguments: {\"query\": str}\n",
    "\n",
    "    Returns:\n",
    "        ToolResult: A tool result containing the newline-separated list of Wikipedia search results.\n",
    "            The result field contains the search results as a string.\n",
    "    \"\"\"\n",
    "    search_wikipedia_result = \"\\n\".join(wikipedia.search(tool_call.arguments[\"query\"]))\n",
    "\n",
    "    return ToolResult(\n",
    "        name=\"search_wikipedia\",\n",
    "        id=tool_call.id,\n",
    "        result=search_wikipedia_result,\n",
    "    )\n",
    "\n",
    "\n",
    "def load_wikipedia_page(tool_call: ToolCall) -> ToolResult:\n",
    "    \"\"\"\n",
    "    Loads and formats the content of a Wikipedia page.\n",
    "\n",
    "    Args:\n",
    "        tool_call (ToolCall): A tool call object containing the page title in its arguments.\n",
    "            Expected arguments: {\"title\": str}\n",
    "\n",
    "    Returns:\n",
    "        ToolResult: A tool result containing the formatted Wikipedia page content.\n",
    "            The result field contains the page URL and content in Markdown format.\n",
    "            If the page is not found or there's a disambiguation error, returns an error message.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        page = wikipedia.page(tool_call.arguments[\"title\"])\n",
    "        # Preprocess result by converting the HTML content to Markdown to reduce token usage\n",
    "        page_markdown = markdownify(page.html())\n",
    "        load_wikipedia_page_result = f\"# URL\\n\\n{page.url}\\n\\n# CONTENT\\n\\n{page_markdown}\"\n",
    "    except wikipedia.exceptions.PageError:\n",
    "        load_wikipedia_page_result = f\"ERROR: page '{tool_call.arguments['title']}' not found.\"\n",
    "    except wikipedia.exceptions.DisambiguationError as e:\n",
    "        load_wikipedia_page_result = f\"ERROR: disambiguation error for '{tool_call.arguments['title']}': {e}\"\n",
    "\n",
    "    return ToolResult(\n",
    "        name=\"load_wikipedia_page\",\n",
    "        id=tool_call.id,\n",
    "        result=load_wikipedia_page_result,\n",
    "    )"
   ]
  },
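  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before wiring the tools into the agent, it can help to call one by hand.\n",
    "\n",
    "The cell below is a minimal sketch, not part of the original example. It assumes network access to Wikipedia and uses a `SimpleNamespace` as a stand-in for a `ToolCall`, which works only because the tool functions above read nothing but `.id` and `.arguments`. The ID and query are made-up values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from types import SimpleNamespace\n",
    "\n",
    "# Hypothetical stand-in for a ToolCall (illustration only): the tool functions\n",
    "# defined above only access `.id` and `.arguments`.\n",
    "example_tool_call = SimpleNamespace(id=\"example-id\", arguments={\"query\": \"positron\"})\n",
    "\n",
    "# Requires network access; prints one search result title per line.\n",
    "example_result = search_wikipedia(example_tool_call)\n",
    "print(example_result.result)"
   ]
  },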
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Agentic RAG\n",
    "\n",
    "Here we define the function that will be used to ask a question to the multi-hop retrieval agent.\n",
    "\n",
    "The function takes a question and launches a multi-hop retrieval process.\n",
    "The agent will make a number of tool calls to search for information and answer the question.\n",
    "\n",
    "The function will return the answer to the question.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "async def ask_question(question: str, verbose: bool = False):\n",
    "    \"\"\"\n",
    "    Asks a question to the multi-hop retrieval agent and returns the answer.\n",
    "\n",
    "    Args:\n",
    "        question (str): The question to ask the agent.\n",
    "        verbose (bool, optional): Whether to print verbose output. Defaults to False.\n",
    "\n",
    "    Returns:\n",
    "        str: The answer to the question.\n",
    "    \"\"\"\n",
    "    # Initialize the message history with the user's question\n",
    "    messages = [{\"role\": \"user\", \"content\": question}]\n",
    "\n",
    "    # The episode ID is used to track the agent's progress (`None` until the first inference)\n",
    "    episode_id = None\n",
    "\n",
    "    for _ in range(MAX_INFERENCES):\n",
    "        print()\n",
    "        response = await t0.inference(\n",
    "            function_name=\"multi_hop_rag_agent\",\n",
    "            input={\"messages\": messages},\n",
    "            episode_id=episode_id,\n",
    "        )\n",
    "\n",
    "        # Append the assistant's response to the messages\n",
    "        messages.append({\"role\": \"assistant\", \"content\": response.content})\n",
    "\n",
    "        # Update the episode ID\n",
    "        episode_id = response.episode_id\n",
    "\n",
    "        # Start constructing the tool call results\n",
    "        output_content_blocks = []\n",
    "\n",
    "        for content_block in response.content:\n",
    "            if isinstance(content_block, ToolCall):\n",
    "                if verbose:\n",
    "                    print(f\"[Tool Call] {content_block.name}: {content_block.arguments}\")\n",
    "\n",
    "                if content_block.name is None or content_block.arguments is None:\n",
    "                    output_content_blocks.append(\n",
    "                        ToolResult(\n",
    "                            name=content_block.raw_name,\n",
    "                            id=content_block.id,\n",
    "                            result=\"ERROR: invalid tool call\",\n",
    "                        )\n",
    "                    )\n",
    "                elif content_block.name == \"search_wikipedia\":\n",
    "                    output_content_blocks.append(search_wikipedia(content_block))\n",
    "                elif content_block.name == \"load_wikipedia_page\":\n",
    "                    output_content_blocks.append(load_wikipedia_page(content_block))\n",
    "                elif content_block.name == \"think\":\n",
    "                    # The `think` tool is just used to plan the next steps, and there's no actual tool to call.\n",
    "                    # Some providers like OpenAI require a tool result, so we'll provide an empty string.\n",
    "                    output_content_blocks.append(\n",
    "                        ToolResult(\n",
    "                            name=\"think\",\n",
    "                            id=content_block.id,\n",
    "                            result=\"\",\n",
    "                        )\n",
    "                    )\n",
    "                elif content_block.name == \"answer_question\":\n",
    "                    return content_block.arguments[\"answer\"]\n",
    "            else:\n",
    "                # We don't need to do anything with other content blocks.\n",
    "                print(f\"[Other Content Block] {content_block}\")\n",
    "\n",
    "        messages.append({\"role\": \"user\", \"content\": output_content_blocks})\n",
    "    else:\n",
    "        # In a production setting, the model could attempt to generate an answer using available information\n",
    "        # when the search process is stopped; here, we simply throw an exception.\n",
    "        raise Exception(f\"Failed to answer question after {MAX_INFERENCES} inferences.\")"
   ]
  },
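  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As the number of tools grows, the `if`/`elif` chain in `ask_question` could be replaced by a lookup table.\n",
    "\n",
    "The cell below is only a sketch under that assumption: `TOOL_HANDLERS` and `dispatch_tool_call` are hypothetical names that are not part of this example or of TensorZero. It reuses `search_wikipedia`, `load_wikipedia_page`, `ToolCall`, and `ToolResult` from the cells above, and leaves `think` and `answer_question` to the loop because they don't call an external tool."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical lookup table (illustration only): tool name -> handler function.\n",
    "TOOL_HANDLERS = {\n",
    "    \"search_wikipedia\": search_wikipedia,\n",
    "    \"load_wikipedia_page\": load_wikipedia_page,\n",
    "}\n",
    "\n",
    "\n",
    "def dispatch_tool_call(tool_call: ToolCall) -> ToolResult:\n",
    "    \"\"\"Routes a tool call to its handler, or returns an error result.\"\"\"\n",
    "    # Same validity check as in `ask_question`.\n",
    "    if tool_call.name is None or tool_call.arguments is None:\n",
    "        return ToolResult(\n",
    "            name=tool_call.raw_name,\n",
    "            id=tool_call.id,\n",
    "            result=\"ERROR: invalid tool call\",\n",
    "        )\n",
    "\n",
    "    handler = TOOL_HANDLERS.get(tool_call.name)\n",
    "\n",
    "    # Report unknown tools back to the model instead of silently skipping them.\n",
    "    if handler is None:\n",
    "        return ToolResult(\n",
    "            name=tool_call.name,\n",
    "            id=tool_call.id,\n",
    "            result=f\"ERROR: unknown tool '{tool_call.name}'\",\n",
    "        )\n",
    "\n",
    "    return handler(tool_call)"
   ]
  },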
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Let's try it out!\n",
    "\n",
    "Let's try our RAG agent on a few questions.\n",
    "\n",
    "The questions are fairly challenging. \n",
    "We present a rough research path that the agent can take to answer the question.\n",
    "GPT-4o Mini often gets them right, it's not always reliable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "await ask_question(\n",
    "    \"What is a common dish in the hometown of the scientist that won the Nobel Prize for the discovery of the positron?\",\n",
    "    verbose=True,\n",
    ")\n",
    "\n",
    "# Expected Answer: Nobel Prize for the discovery of the positron -> Carl D. Anderson -> New York City -> a popular NYC dish"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "await ask_question(\n",
    "    \"What company developed the popular Chinese video game voiced by the same voice actor that voiced a wizard in the anime Konosuba?\",\n",
    "    verbose=True,\n",
    ")\n",
    "\n",
    "# Expected Answer: Konosuba's wizard -> Megumin -> voiced by Rie Takahashi -> Chinese video game -> Genshin Impact -> developed by HoYoverse (miHoYo)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "await ask_question(\n",
    "    \"What is the national flower of the country where the mathematician who proved Fermat's Last Theorem was born?\",\n",
    "    verbose=True,\n",
    ")\n",
    "\n",
    "# Expected Answer: Fermat's Last Theorem -> Andrew Wiles -> United Kingdom -> national flower -> Tudor rose (red rose)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
